'''
Created on May 21, 2018

@author: helrewaidy
'''
# sub-parts of the U-Net model

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from complexnet.cmplxconv import ComplexConv2d, ComplexConv3d
from complexnet.radialbn2 import RadialBatchNorm2d, RadialBatchNorm3d
from complexnet.cmplxupsample import ComplexUpsample
from complexnet.cmplxdropout import ComplexDropout2d

from parameters import Parameters
# Shared configuration object; the factories and layer classes below read
# activation_func, network_type and dropout_ratio from it.
params = Parameters()

def Activation(*args, activation_func=None):
    '''Return the activation module selected by the configuration.

    Positional arguments (call sites pass the channel count) are accepted for
    compatibility and ignored.

    Args:
        activation_func: optional override; when None, ``params.activation_func``
            is used.

    Returns:
        ``nn.ReLU(inplace=True)`` for ``'CReLU'``;
        ``nn.LeakyReLU(negative_slope=0.1, inplace=True)`` for ``'CLeakyReLU'``
        (the legacy misspelling ``'CLeakyeak'`` is still accepted).

    Raises:
        ValueError: for any other name, instead of silently returning None
            (a None would only surface later, when the enclosing
            nn.Sequential is run).
    '''
    name = params.activation_func if activation_func is None else activation_func
    if name == 'CReLU':
        return nn.ReLU(inplace=True)
    if name in ('CLeakyReLU', 'CLeakyeak'):
        return nn.LeakyReLU(negative_slope=0.1, inplace=True)
    raise ValueError("Unsupported activation_func: %r (expected 'CReLU' or 'CLeakyReLU')" % (name,))


# Bind the 2D or 3D flavour of the complex conv / radial batch-norm layers once,
# at import time. For any other network_type both names stay undefined, and the
# classes that use them fail with NameError only when instantiated.
if params.network_type == '2D':
    ComplexConv, RadialBatchNorm = ComplexConv2d, RadialBatchNorm2d
elif params.network_type == '3D':
    ComplexConv, RadialBatchNorm = ComplexConv3d, RadialBatchNorm3d


class double_conv(nn.Module):
    '''Two stacked (conv => BN => activation => dropout) stages.

    The first stage maps in_ch -> out_ch channels and the second out_ch -> out_ch.
    Convolutions use kernel size 3 with padding 1; dropout uses
    params.dropout_ratio.
    '''

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        stages = []
        for stage_in_ch in (in_ch, out_ch):
            stages += [
                ComplexConv(stage_in_ch, out_ch, 3, padding=1),
                RadialBatchNorm(out_ch),
                Activation(out_ch),
                ComplexDropout2d(params.dropout_ratio),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class down_conv(nn.Module):
    '''Strided conv => BN => activation that keeps the channel count at in_ch.

    The conv uses kernel size 3, padding 1 and stride (2, 2), or (2, 2, 1) when
    params.network_type is '3D', so the third spatial axis is not strided.
    '''

    def __init__(self, in_ch):
        super(down_conv, self).__init__()
        if params.network_type == '3D':
            stride = (2, 2, 1)
        else:
            stride = (2, 2)
        layers = [
            ComplexConv(in_ch, in_ch, 3, stride=stride, padding=1),
            RadialBatchNorm(in_ch),
            Activation(in_ch),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class inconv(nn.Module):
    '''Input stage: a single double_conv mapping in_ch -> out_ch channels.'''

    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)

class bnorm(nn.Module):
    '''Thin module wrapper around RadialBatchNorm(in_ch).'''

    def __init__(self, in_ch):
        super(bnorm, self).__init__()
        self.norm = RadialBatchNorm(in_ch)

    def forward(self, x):
        return self.norm(x)


class down(nn.Module):
    '''Encoder stage: strided down_conv followed by a double_conv.

    forward returns a pair ``(features, downsampled)``: the double_conv output
    and the intermediate tensor produced by down_conv.
    '''

    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.down_conv = down_conv(in_ch)
        self.double_conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        reduced = self.down_conv(x)
        return self.double_conv(reduced), reduced


class bottleneck(nn.Module):
    '''Downsample, widen to 2 * in_ch, then project to out_ch.

    Each of the two 3x3 conv stages is preceded by dropout (params.dropout_ratio)
    and followed by RadialBatchNorm and the configured activation. When
    ``residual_connection`` is true the downsampled input is added to the
    result, so out_ch must match in_ch (or broadcast against it).
    '''

    def __init__(self, in_ch, out_ch, residual_connection=True):
        super(bottleneck, self).__init__()
        self.residual_connection = residual_connection
        self.down_conv = down_conv(in_ch)
        wide_ch = 2 * in_ch
        self.double_conv = nn.Sequential(
            ComplexDropout2d(params.dropout_ratio),
            ComplexConv(in_ch, wide_ch, 3, padding=1),
            RadialBatchNorm(wide_ch),
            Activation(wide_ch),
            ComplexDropout2d(params.dropout_ratio),
            ComplexConv(wide_ch, out_ch, 3, padding=1),
            RadialBatchNorm(out_ch),
            Activation(out_ch),
        )

    def forward(self, x):
        down_x = self.down_conv(x)
        features = self.double_conv(down_x)
        if not self.residual_connection:
            return features
        return features + down_x


class up(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(up, self).__init__()
        upsample_mode = 'trilinear' if params.network_type == '3D' else 'bilinear'
        upsample_scale_factor = (2, 2, 1) if params.network_type == '3D' else (2, 2)
        self.up = ComplexUpsample(scale_factor=upsample_scale_factor, mode=upsample_mode)

        self.conv = nn.Sequential(
            ComplexConv(in_ch * 2, in_ch, 3, padding=1),
            RadialBatchNorm(in_ch),  # ComplexBatchNormalize(in_ch),
            Activation(in_ch),  # nn.ReLU(inplace=True),
            ComplexConv(in_ch, out_ch, 3, padding=1),
            RadialBatchNorm(out_ch),  # ComplexBatchNormalize(out_ch),
            Activation(out_ch)  # nn.ReLU(inplace=True)
        )

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffX = x1.size()[2] - x2.size()[2]
        diffY = x1.size()[3] - x2.size()[3]
        x2 = F.pad(x2, (diffX // 2, int(diffX / 2),
                        diffY // 2, int(diffY / 2)))
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class mag_phase_combine(nn.Module):
    '''Fold the two halves of dim 2 into the channel dim and mix them with a size-1 conv.

    ``forward`` splits dim 2 into chunks of ``size(2) // 2`` and concatenates the
    first two chunks along dim 1. For an odd dim-2 size the trailing 1-wide
    chunk is dropped. NOTE(review): the name suggests the halves hold magnitude
    and phase; confirm with the caller.
    '''

    def __init__(self, in_ch, out_ch):
        super(mag_phase_combine, self).__init__()
        # An int padding works for both the 2D and 3D ComplexConv flavours
        # (other layers in this file pass int padding too). A 2-tuple only
        # matches the 2D case.
        self.conv1d = nn.Sequential(
            ComplexConv(in_ch, out_ch, 1, padding=0)
        )

    def forward(self, x):
        chunks = torch.split(x, x.size(2) // 2, dim=2)
        first, second = chunks[0], chunks[1]
        return self.conv1d(torch.cat([first, second], dim=1))


class outconv(nn.Module):
    '''Output head: a kernel-size-1 ComplexConv mapping in_ch -> out_ch channels.'''

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = ComplexConv(in_ch, out_ch, 1)

    def forward(self, x):
        return self.conv(x)


class conv(nn.Module):
    '''ComplexConv with "same"-style padding, optionally followed by BN and LeakyReLU.

    Padding is ``dilation * (kernel_size - 1) // 2`` (int kernel_size and
    dilation). The optional activation is always LeakyReLU(0.1), independent
    of params.activation_func.
    '''

    def __init__(self, in_ch, out_ch, kernel_size=3, dilation=1, bias=True, groups=1, apply_BN=False, apply_activation=False):
        super(conv, self).__init__()
        self.padding = dilation * (kernel_size - 1) // 2
        self.apply_BN = apply_BN
        modules = [ComplexConv(in_ch, out_ch, kernel_size, padding=self.padding,
                               dilation=dilation, bias=bias, groups=groups)]
        modules += [RadialBatchNorm(out_ch)] if apply_BN else []
        modules += [nn.LeakyReLU(negative_slope=0.1, inplace=True)] if apply_activation else []
        self.conv = nn.Sequential(*modules)

    def forward(self, x):
        return self.conv(x)


class conv_3D(nn.Module):
    '''ComplexConv3d with per-axis "same"-style padding, optionally followed by BN and LeakyReLU.

    Args:
        kernel_size: int (used for all three axes) or a per-axis sequence;
            forwarded unchanged to ComplexConv3d.
        dilation: int or a per-axis sequence; forwarded unchanged to ComplexConv3d.
        apply_BN: append RadialBatchNorm3d(out_ch).
        apply_activation: append LeakyReLU(0.1) (independent of
            params.activation_func).
    '''

    def __init__(self, in_ch, out_ch, kernel_size=3, dilation=1, bias=True, groups=1, apply_BN=False, apply_activation=False):
        super(conv_3D, self).__init__()
        self.padding = self._same_padding(kernel_size, dilation)

        self.apply_BN = apply_BN
        layers = [ComplexConv3d(in_ch, out_ch, kernel_size, padding=self.padding, dilation=dilation, bias=bias, groups=groups)]

        if apply_BN:
            layers.append(RadialBatchNorm3d(out_ch))

        if apply_activation:
            layers.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))

        self.conv = nn.Sequential(*layers)

    @staticmethod
    def _same_padding(kernel_size, dilation):
        '''Return ``[d * (k - 1) // 2]`` per axis; scalar arguments are broadcast.

        Raises:
            ValueError: if kernel_size and dilation sequences differ in length.
        '''
        # A scalar kernel previously crashed here (iterating an int).
        if isinstance(kernel_size, (list, tuple)):
            kernels = list(kernel_size)
        else:
            kernels = [kernel_size] * 3
        if isinstance(dilation, (list, tuple)):
            dilations = list(dilation)
        else:
            dilations = [dilation] * len(kernels)
        if len(dilations) != len(kernels):
            raise ValueError('kernel_size and dilation must have the same number of axes, got %r and %r'
                             % (kernel_size, dilation))
        return [d * (k - 1) // 2 for k, d in zip(kernels, dilations)]

    def forward(self, x):
        return self.conv(x)


class conv_ri(nn.Module):
    '''Real-valued nn.Conv3d padding only the first two spatial axes, optionally followed by BN and ReLU.

    Padding is ``(p, p, 0)`` with ``p = d * (k - 1) // 2``, where ``k`` is the
    first kernel size and ``d`` the first dilation; the last axis is never padded.

    Args:
        kernel_size: int or sequence; forwarded unchanged to nn.Conv3d.
        dilation: int or sequence; forwarded unchanged to nn.Conv3d.
        apply_BN: append nn.BatchNorm3d(out_ch).
        apply_activation: append an in-place nn.ReLU.
    '''

    def __init__(self, in_ch, out_ch, kernel_size=3, dilation=1, bias=True, groups=1, apply_BN=False, apply_activation=False):
        super(conv_ri, self).__init__()
        # Accept scalars as well as sequences; the int default used to crash on kernel_size[0].
        first_k = kernel_size[0] if isinstance(kernel_size, (list, tuple)) else kernel_size
        first_d = dilation[0] if isinstance(dilation, (list, tuple)) else dilation
        self.padding = first_d * (first_k - 1) // 2
        self.apply_BN = apply_BN
        layers = [nn.Conv3d(in_ch, out_ch, kernel_size, padding=(self.padding, self.padding, 0),
                            dilation=dilation, bias=bias, groups=groups)]

        if apply_BN:
            layers.append(nn.BatchNorm3d(out_ch))

        if apply_activation:
            # nn.ReLU's only parameter is `inplace`; set it explicitly rather
            # than passing a channel count into it.
            layers.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class stacked_3Dconvs_block(nn.Module):
    '''Three chained conv_3D layers: in_ch -> mid -> mid -> out_ch.

    ``mid = round(bottelneck_ratio * in_ch)``. The first two layers honour
    apply_BN / apply_activation; the last one is always a bare convolution.

    Raises:
        ValueError: if the rounded bottleneck width is below 1.
    '''

    def __init__(self, in_ch, out_ch, kernel_size=(3,3,3), dilations=1, apply_BN=False, apply_activation=False, bottelneck_ratio=1):
        super(stacked_3Dconvs_block, self).__init__()
        # np.round returns a NumPy scalar (float64 for fractional ratios);
        # cast so the convolutions receive a plain int channel count.
        mid_ch = int(np.round(bottelneck_ratio * in_ch))
        if mid_ch < 1:
            raise ValueError('bottelneck_ratio=%r with in_ch=%r gives %d bottleneck channels'
                             % (bottelneck_ratio, in_ch, mid_ch))
        self.conv_stack = nn.Sequential(
            conv_3D(in_ch, mid_ch, kernel_size=kernel_size, dilation=dilations, apply_BN=apply_BN, apply_activation=apply_activation),
            conv_3D(mid_ch, mid_ch, kernel_size=kernel_size, dilation=dilations, apply_BN=apply_BN, apply_activation=apply_activation),
            conv_3D(mid_ch, out_ch, kernel_size=kernel_size, dilation=dilations, apply_BN=False, apply_activation=False)
        )

    def forward(self, x):
        return self.conv_stack(x)

